{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "nmt_with_attention.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "colab_type": "text",
        "id": "AOpGoE2T-YXS"
      },
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
        "\n",
        "# Neural Machine Translation with Attention\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
        "<a target=\"_blank\"  href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>  \n",
        "</td><td>\n",
        "<a target=\"_blank\"  href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "CiwtNgENbx2g"
      },
      "cell_type": "markdown",
      "source": [
        "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n",
        "\n",
        "After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n",
        "\n",
        "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n",
        "\n",
        "<img src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\">\n",
        "\n",
        "Note: This example takes approximately 10 mintues to run on a single P100 GPU."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "tnxXKDjq3jEL",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "from __future__ import absolute_import, division, print_function\n",
        "\n",
        "# Import TensorFlow >= 1.10 and enable eager execution\n",
        "import tensorflow as tf\n",
        "\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "import unicodedata\n",
        "import re\n",
        "import numpy as np\n",
        "import os\n",
        "import time\n",
        "\n",
        "print(tf.__version__)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "wfodePkj3jEa"
      },
      "cell_type": "markdown",
      "source": [
        "## Download and prepare the dataset\n",
        "\n",
        "We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n",
        "\n",
        "```\n",
        "May I borrow this book?\t¿Puedo tomar prestado este libro?\n",
        "```\n",
        "\n",
        "There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n",
        "\n",
        "1. Add a *start* and *end* token to each sentence.\n",
        "2. Clean the sentences by removing special characters.\n",
        "3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n",
        "4. Pad each sentence to a maximum length."
      ]
    },
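    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The four preparation steps above can be sketched on a toy corpus. This is a minimal illustration using only `re` and NumPy, not the notebook's actual pipeline: the helper name `toy_prepare` and the two example sentences are assumptions made for this sketch.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def toy_prepare(sentences):\n",
        "    # Steps 1-2: lowercase, space out punctuation, add <start>/<end> tokens.\n",
        "    cleaned = []\n",
        "    for s in sentences:\n",
        "        s = re.sub(r'([?.!,])', r' \\1 ', s.lower())\n",
        "        s = re.sub(r'\\s+', ' ', s).strip()\n",
        "        cleaned.append('<start> ' + s + ' <end>')\n",
        "\n",
        "    # Step 3: word -> id and id -> word dictionaries; id 0 is reserved for padding.\n",
        "    vocab = sorted({w for s in cleaned for w in s.split(' ')})\n",
        "    word2idx = {'<pad>': 0}\n",
        "    word2idx.update({w: i + 1 for i, w in enumerate(vocab)})\n",
        "    idx2word = {i: w for w, i in word2idx.items()}\n",
        "\n",
        "    # Step 4: pad every id sequence with zeros up to the longest sentence.\n",
        "    ids = [[word2idx[w] for w in s.split(' ')] for s in cleaned]\n",
        "    padded = np.zeros((len(ids), max(len(x) for x in ids)), dtype=np.int64)\n",
        "    for row, seq in enumerate(ids):\n",
        "        padded[row, :len(seq)] = seq\n",
        "    return cleaned, word2idx, idx2word, padded\n",
        "\n",
        "\n",
        "cleaned, word2idx, idx2word, padded = toy_prepare(['Hi.', 'How are you?'])\n",
        "print(cleaned)\n",
        "print(padded)\n",
        "```\n",
        "\n",
        "Reserving id 0 for `<pad>` means padded positions can later be recognised and masked out of the loss."
      ]
    },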
    {
      "metadata": {
        "colab_type": "code",
        "id": "kRVATYOgJs1b",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Download the file\n",
        "path_to_zip = tf.keras.utils.get_file(\n",
        "    'spa-eng.zip', origin='https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip', \n",
        "    extract=True)\n",
        "\n",
        "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "rd0jw-eC3jEh",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Converts the unicode file to ascii\n",
        "def unicode_to_ascii(s):\n",
        "    return ''.join(c for c in unicodedata.normalize('NFD', s)\n",
        "        if unicodedata.category(c) != 'Mn')\n",
        "\n",
        "\n",
        "def preprocess_sentence(w):\n",
        "    w = unicode_to_ascii(w.lower().strip())\n",
        "    \n",
        "    # creating a space between a word and the punctuation following it\n",
        "    # eg: \"he is a boy.\" => \"he is a boy .\" \n",
        "    # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
        "    w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
        "    w = re.sub(r'[\" \"]+', \" \", w)\n",
        "    \n",
        "    # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n",
        "    w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
        "    \n",
        "    w = w.rstrip().strip()\n",
        "    \n",
        "    # adding a start and an end token to the sentence\n",
        "    # so that the model know when to start and stop predicting.\n",
        "    w = '<start> ' + w + ' <end>'\n",
        "    return w"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "OHn4Dct23jEm",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# 1. Remove the accents\n",
        "# 2. Clean the sentences\n",
        "# 3. Return word pairs in the format: [ENGLISH, SPANISH]\n",
        "def create_dataset(path, num_examples):\n",
        "    lines = open(path, encoding='UTF-8').read().strip().split('\\n')\n",
        "    \n",
        "    word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')]  for l in lines[:num_examples]]\n",
        "    \n",
        "    return word_pairs"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "9xbqO7Iie9bb",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n",
        "# (e.g., 5 -> \"dad\") for each language,\n",
        "class LanguageIndex():\n",
        "  def __init__(self, lang):\n",
        "    self.lang = lang\n",
        "    self.word2idx = {}\n",
        "    self.idx2word = {}\n",
        "    self.vocab = set()\n",
        "    \n",
        "    self.create_index()\n",
        "    \n",
        "  def create_index(self):\n",
        "    for phrase in self.lang:\n",
        "      self.vocab.update(phrase.split(' '))\n",
        "    \n",
        "    self.vocab = sorted(self.vocab)\n",
        "    \n",
        "    self.word2idx['<pad>'] = 0\n",
        "    for index, word in enumerate(self.vocab):\n",
        "      self.word2idx[word] = index + 1\n",
        "    \n",
        "    for word, index in self.word2idx.items():\n",
        "      self.idx2word[index] = word"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "eAY9k49G3jE_",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def max_length(tensor):\n",
        "    return max(len(t) for t in tensor)\n",
        "\n",
        "\n",
        "def load_dataset(path, num_examples):\n",
        "    # creating cleaned input, output pairs\n",
        "    pairs = create_dataset(path, num_examples)\n",
        "\n",
        "    # index language using the class defined above    \n",
        "    inp_lang = LanguageIndex(sp for en, sp in pairs)\n",
        "    targ_lang = LanguageIndex(en for en, sp in pairs)\n",
        "    \n",
        "    # Vectorize the input and target languages\n",
        "    \n",
        "    # Spanish sentences\n",
        "    input_tensor = [[inp_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n",
        "    \n",
        "    # English sentences\n",
        "    target_tensor = [[targ_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n",
        "    \n",
        "    # Calculate max_length of input and output tensor\n",
        "    # Here, we'll set those to the longest sentence in the dataset\n",
        "    max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)\n",
        "    \n",
        "    # Padding the input and output tensor to the maximum length\n",
        "    input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, \n",
        "                                                                 maxlen=max_length_inp,\n",
        "                                                                 padding='post')\n",
        "    \n",
        "    target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, \n",
        "                                                                  maxlen=max_length_tar, \n",
        "                                                                  padding='post')\n",
        "    \n",
        "    return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "GOi42V79Ydlr"
      },
      "cell_type": "markdown",
      "source": [
        "### Limit the size of the dataset to experiment faster (optional)\n",
        "\n",
        "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "cnxC7q-j3jFD",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Try experimenting with the size of that dataset\n",
        "num_examples = 30000\n",
        "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "4QILQkOs3jFG",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Creating training and validation sets using an 80-20 split\n",
        "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n",
        "\n",
        "# Show length\n",
        "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "rgCLkfv5uO3d"
      },
      "cell_type": "markdown",
      "source": [
        "### Create a tf.data dataset"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "TqHsArVZ3jFS",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "BUFFER_SIZE = len(input_tensor_train)\n",
        "BATCH_SIZE = 64\n",
        "N_BATCH = BUFFER_SIZE//BATCH_SIZE\n",
        "embedding_dim = 256\n",
        "units = 1024\n",
        "vocab_inp_size = len(inp_lang.word2idx)\n",
        "vocab_tar_size = len(targ_lang.word2idx)\n",
        "\n",
        "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
        "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "TNfHIF71ulLu"
      },
      "cell_type": "markdown",
      "source": [
        "## Write the encoder and decoder model\n",
        "\n",
        "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://github.com/tensorflow/nmt). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism) from the seq2seq tutorial. The following diagram shows that each input word is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n",
        "\n",
        "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\">\n",
        "\n",
        "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n",
        "\n",
        "Here are the equations that are implemented:\n",
        "\n",
        "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\">\n",
        "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\">\n",
        "\n",
        "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n",
        "\n",
        "* FC = Fully connected (dense) layer\n",
        "* EO = Encoder output\n",
        "* H = hidden state\n",
        "* X = input to the decoder\n",
        "\n",
        "And the pseudo-code:\n",
        "\n",
        "* `score = FC(tanh(FC(EO) + FC(H)))`\n",
        "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, 1)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n",
        "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n",
        "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n",
        "* `merged vector = concat(embedding output, context vector)`\n",
        "* This merged vector is then given to the GRU\n",
        "  \n",
        "The shapes of all the vectors at each step have been specified in the comments in the code:"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "avyJ_4VIUoHb",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def gru(units):\n",
        "  # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n",
        "  # the code automatically does that.\n",
        "  if tf.test.is_gpu_available():\n",
        "    return tf.keras.layers.CuDNNGRU(units, \n",
        "                                    return_sequences=True, \n",
        "                                    return_state=True, \n",
        "                                    recurrent_initializer='glorot_uniform')\n",
        "  else:\n",
        "    return tf.keras.layers.GRU(units, \n",
        "                               return_sequences=True, \n",
        "                               return_state=True, \n",
        "                               recurrent_activation='sigmoid', \n",
        "                               recurrent_initializer='glorot_uniform')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "nZ2rI24i3jFg",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class Encoder(tf.keras.Model):\n",
        "    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
        "        super(Encoder, self).__init__()\n",
        "        self.batch_sz = batch_sz\n",
        "        self.enc_units = enc_units\n",
        "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "        self.gru = gru(self.enc_units)\n",
        "        \n",
        "    def call(self, x, hidden):\n",
        "        x = self.embedding(x)\n",
        "        output, state = self.gru(x, initial_state = hidden)        \n",
        "        return output, state\n",
        "    \n",
        "    def initialize_hidden_state(self):\n",
        "        return tf.zeros((self.batch_sz, self.enc_units))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "yJ_B3mhW3jFk",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class Decoder(tf.keras.Model):\n",
        "    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
        "        super(Decoder, self).__init__()\n",
        "        self.batch_sz = batch_sz\n",
        "        self.dec_units = dec_units\n",
        "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "        self.gru = gru(self.dec_units)\n",
        "        self.fc = tf.keras.layers.Dense(vocab_size)\n",
        "        \n",
        "        # used for attention\n",
        "        self.W1 = tf.keras.layers.Dense(self.dec_units)\n",
        "        self.W2 = tf.keras.layers.Dense(self.dec_units)\n",
        "        self.V = tf.keras.layers.Dense(1)\n",
        "        \n",
        "    def call(self, x, hidden, enc_output):\n",
        "        # enc_output shape == (batch_size, max_length, hidden_size)\n",
        "        \n",
        "        # hidden shape == (batch_size, hidden size)\n",
        "        # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
        "        # we are doing this to perform addition to calculate the score\n",
        "        hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
        "        \n",
        "        # score shape == (batch_size, max_length, 1)\n",
        "        # we get 1 at the last axis because we are applying tanh(FC(EO) + FC(H)) to self.V\n",
        "        score = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)))\n",
        "        \n",
        "        # attention_weights shape == (batch_size, max_length, 1)\n",
        "        attention_weights = tf.nn.softmax(score, axis=1)\n",
        "        \n",
        "        # context_vector shape after sum == (batch_size, hidden_size)\n",
        "        context_vector = attention_weights * enc_output\n",
        "        context_vector = tf.reduce_sum(context_vector, axis=1)\n",
        "        \n",
        "        # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
        "        x = self.embedding(x)\n",
        "        \n",
        "        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
        "        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
        "        \n",
        "        # passing the concatenated vector to the GRU\n",
        "        output, state = self.gru(x)\n",
        "        \n",
        "        # output shape == (batch_size * 1, hidden_size)\n",
        "        output = tf.reshape(output, (-1, output.shape[2]))\n",
        "        \n",
        "        # output shape == (batch_size * 1, vocab)\n",
        "        x = self.fc(output)\n",
        "        \n",
        "        return x, state, attention_weights\n",
        "        \n",
        "    def initialize_hidden_state(self):\n",
        "        return tf.zeros((self.batch_sz, self.dec_units))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "P5UY8wko3jFp",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
        "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "_ch_71VbIRfK"
      },
      "cell_type": "markdown",
      "source": [
        "## Define the optimizer and the loss function"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "WmTHr5iV3jFr",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "optimizer = tf.train.AdamOptimizer()\n",
        "\n",
        "\n",
        "def loss_function(real, pred):\n",
        "  mask = 1 - np.equal(real, 0)\n",
        "  loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
        "  return tf.reduce_mean(loss_)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "DMVWzzsfNl4e"
      },
      "cell_type": "markdown",
      "source": [
        "## Checkpoints (Object-based saving)"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "Zj8bXQTgNwrF",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
        "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
        "                                 encoder=encoder,\n",
        "                                 decoder=decoder)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "hpObfY22IddU"
      },
      "cell_type": "markdown",
      "source": [
        "## Training\n",
        "\n",
        "1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n",
        "2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n",
        "3. The decoder returns the *predictions* and the *decoder hidden state*.\n",
        "4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
        "5. Use *teacher forcing* to decide the next input to the decoder.\n",
        "6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n",
        "7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate."
      ]
    },
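    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Steps 5 and 6 can be seen in isolation with a small NumPy sketch. It only shows how decoder inputs, labels and the padding mask line up at each time step under teacher forcing; there is no model here, and the token ids in `targ` are made up for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical padded target batch: 1 = <start>, 2 = <end>, 0 = <pad>.\n",
        "targ = np.array([[1, 5, 6, 2, 0],\n",
        "                 [1, 7, 2, 0, 0]])\n",
        "\n",
        "dec_input = targ[:, 0]                 # every sequence starts with <start>\n",
        "steps = []\n",
        "for t in range(1, targ.shape[1]):\n",
        "    label = targ[:, t]                 # the word the decoder should predict at step t\n",
        "    mask = (label != 0).astype(int)    # padded positions do not contribute to the loss\n",
        "    steps.append((dec_input.tolist(), label.tolist(), mask.tolist()))\n",
        "    dec_input = targ[:, t]             # teacher forcing: feed the true word, not the prediction\n",
        "\n",
        "for decoder_in, label, mask in steps:\n",
        "    print(decoder_in, label, mask)\n",
        "```\n",
        "\n",
        "Because the true word is fed in at every step, an early mistake by the decoder does not derail the rest of the sequence during training."
      ]
    },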
    {
      "metadata": {
        "colab_type": "code",
        "id": "ddefjBMa3jF0",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "EPOCHS = 10\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "    start = time.time()\n",
        "    \n",
        "    hidden = encoder.initialize_hidden_state()\n",
        "    total_loss = 0\n",
        "    \n",
        "    for (batch, (inp, targ)) in enumerate(dataset):\n",
        "        loss = 0\n",
        "        \n",
        "        with tf.GradientTape() as tape:\n",
        "            enc_output, enc_hidden = encoder(inp, hidden)\n",
        "            \n",
        "            dec_hidden = enc_hidden\n",
        "            \n",
        "            dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * BATCH_SIZE, 1)       \n",
        "            \n",
        "            # Teacher forcing - feeding the target as the next input\n",
        "            for t in range(1, targ.shape[1]):\n",
        "                # passing enc_output to the decoder\n",
        "                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
        "                \n",
        "                loss += loss_function(targ[:, t], predictions)\n",
        "                \n",
        "                # using teacher forcing\n",
        "                dec_input = tf.expand_dims(targ[:, t], 1)\n",
        "        \n",
        "        batch_loss = (loss / int(targ.shape[1]))\n",
        "        \n",
        "        total_loss += batch_loss\n",
        "        \n",
        "        variables = encoder.variables + decoder.variables\n",
        "        \n",
        "        gradients = tape.gradient(loss, variables)\n",
        "        \n",
        "        optimizer.apply_gradients(zip(gradients, variables))\n",
        "        \n",
        "        if batch % 100 == 0:\n",
        "            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
        "                                                         batch,\n",
        "                                                         batch_loss.numpy()))\n",
        "    # saving (checkpoint) the model every 2 epochs\n",
        "    if (epoch + 1) % 2 == 0:\n",
        "      checkpoint.save(file_prefix = checkpoint_prefix)\n",
        "    \n",
        "    print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
        "                                        total_loss / N_BATCH))\n",
        "    print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "mU3Ce8M6I3rz"
      },
      "cell_type": "markdown",
      "source": [
        "## Translate\n",
        "\n",
        "* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
        "* Stop predicting when the model predicts the *end token*.\n",
        "* And store the *attention weights for every time step*.\n",
        "\n",
        "Note: The encoder output is calculated only once for one input."
      ]
    },
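    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The decoding procedure described above amounts to a greedy loop. The sketch below illustrates only the control flow: `toy_decoder_step` is a hypothetical stand-in for the trained decoder, and the five-word vocabulary is invented for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "idx2word = {0: '<pad>', 1: '<start>', 2: 'hello', 3: 'world', 4: '<end>'}\n",
        "word2idx = {w: i for i, w in idx2word.items()}\n",
        "\n",
        "\n",
        "def toy_decoder_step(prev_id):\n",
        "    # Hypothetical stand-in for the trained decoder: returns logits that always\n",
        "    # favour the next id, so the output is hello -> world -> <end>.\n",
        "    logits = np.zeros(len(idx2word))\n",
        "    logits[min(prev_id + 1, len(idx2word) - 1)] = 1.0\n",
        "    return logits\n",
        "\n",
        "\n",
        "def greedy_decode(max_length_targ):\n",
        "    result = []\n",
        "    dec_input = word2idx['<start>']\n",
        "    for _ in range(max_length_targ):\n",
        "        predicted_id = int(np.argmax(toy_decoder_step(dec_input)))\n",
        "        result.append(idx2word[predicted_id])\n",
        "        if idx2word[predicted_id] == '<end>':\n",
        "            break\n",
        "        dec_input = predicted_id  # no teacher forcing: feed the prediction back in\n",
        "    return result\n",
        "\n",
        "\n",
        "print(greedy_decode(max_length_targ=10))\n",
        "```\n",
        "\n",
        "The `max_length_targ` bound guarantees termination even if the model never predicts the end token."
      ]
    },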
    {
      "metadata": {
        "colab_type": "code",
        "id": "EbQpyYs13jF_",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
        "    attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
        "    \n",
        "    sentence = preprocess_sentence(sentence)\n",
        "\n",
        "    inputs = [inp_lang.word2idx[i] for i in sentence.split(' ')]\n",
        "    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')\n",
        "    inputs = tf.convert_to_tensor(inputs)\n",
        "    \n",
        "    result = ''\n",
        "\n",
        "    hidden = [tf.zeros((1, units))]\n",
        "    enc_out, enc_hidden = encoder(inputs, hidden)\n",
        "\n",
        "    dec_hidden = enc_hidden\n",
        "    dec_input = tf.expand_dims([targ_lang.word2idx['<start>']], 0)\n",
        "\n",
        "    for t in range(max_length_targ):\n",
        "        predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n",
        "        \n",
        "        # storing the attention weights to plot later on\n",
        "        attention_weights = tf.reshape(attention_weights, (-1, ))\n",
        "        attention_plot[t] = attention_weights.numpy()\n",
        "\n",
        "        predicted_id = tf.argmax(predictions[0]).numpy()\n",
        "\n",
        "        result += targ_lang.idx2word[predicted_id] + ' '\n",
        "\n",
        "        if targ_lang.idx2word[predicted_id] == '<end>':\n",
        "            return result, sentence, attention_plot\n",
        "        \n",
        "        # the predicted ID is fed back into the model\n",
        "        dec_input = tf.expand_dims([predicted_id], 0)\n",
        "\n",
        "    return result, sentence, attention_plot"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "s5hQWlbN3jGF",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# function for plotting the attention weights\n",
        "def plot_attention(attention, sentence, predicted_sentence):\n",
        "    fig = plt.figure(figsize=(10,10))\n",
        "    ax = fig.add_subplot(1, 1, 1)\n",
        "    ax.matshow(attention, cmap='viridis')\n",
        "    \n",
        "    fontdict = {'fontsize': 14}\n",
        "    \n",
        "    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n",
        "    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n",
        "\n",
        "    plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "sl9zUHzg3jGI",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
        "    result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n",
        "        \n",
        "    print('Input: {}'.format(sentence))\n",
        "    print('Predicted translation: {}'.format(result))\n",
        "    \n",
        "    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n",
        "    plot_attention(attention_plot, sentence.split(' '), result.split(' '))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "n250XbnjOaqP"
      },
      "cell_type": "markdown",
      "source": [
        "## Restore the latest checkpoint and test"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "UJpT9D5_OgP6",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# restoring the latest checkpoint in checkpoint_dir\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "WrAM0FDomq3E",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "zSx2iM36EZQZ",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "A3LLCx3ZE0Ls",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "DUQVLVqUE1YW",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# wrong translation\n",
        "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "RTe5P5ioMJwN"
      },
      "cell_type": "markdown",
      "source": [
        "## Next steps\n",
        "\n",
        "* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n",
        "* Experiment with training on a larger dataset, or using more epochs\n"
      ]
    }
  ]
}
